{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bdd1f95a",
   "metadata": {},
   "source": [
    "# Transform Overview\n",
    "\n",
    "{mod}`tvm.ir.transform` defines the common pass infrastructure shared across IR variants such as Relax and TIR.\n",
    "\n",
    "````{tab-set}\n",
    "```{tab-item} PassInfo\n",
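    "{class}`tvm.ir.transform.PassInfo` holds the metadata required by a pass. It is a container for the information needed to run an optimization or analysis, and it can be extended with new members when more metadata is needed.\n",
    "\n",
    "- `name` (`str`): the name of the pass\n",
    "- `opt_level` (`int`): the optimization level at which the pass is enabled\n",
    "- `required` (`list[str]`): the names of the passes this pass depends on\n",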
    "```\n",
    "```{tab-item} Pass\n",
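    "The base class of all passes. Its methods are thin wrappers around implementations in the C++ backend; they exist so that users can conveniently interact with passes from Python.\n",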
    "```\n",
    "```{tab-item} PassContext\n",
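    "{class}`tvm.ir.transform.PassContext` is the environment in which an optimization or analysis runs.\n",
    "\n",
    "Each pass context carries auxiliary information that helps the optimization passes, such as the optimization level (`opt_level`), the passes that must or must not run (`required_pass`, `disabled_pass`), and the diagnostics used to report errors during optimization.\n",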
    "```\n",
    "````"
   ]
  },
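  {
   "cell_type": "markdown",
   "id": "c3a9e1f0",
   "metadata": {},
   "source": [
    "A minimal sketch of how `PassInfo` and `PassContext` can be used directly from Python, assuming a recent TVM build. The pass name `MyPass` is hypothetical; `PassContext.current()` returns the innermost context entered with `with`.\n",
    "\n",
    "```python\n",
    "import tvm\n",
    "\n",
    "# Metadata for a hypothetical pass 'MyPass' (opt_level 2) that depends on 'FoldConstant'\n",
    "info = tvm.ir.transform.PassInfo(2, 'MyPass', required=['FoldConstant'])\n",
    "print(info.name, info.opt_level, list(info.required))\n",
    "\n",
    "# Entering the context makes it the current one for every pass run inside the block\n",
    "with tvm.transform.PassContext(opt_level=3, disabled_pass=['FoldConstant']):\n",
    "    current = tvm.transform.PassContext.current()\n",
    "    print(current.opt_level, list(current.disabled_pass))\n",
    "```"
   ]
  },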
  {
   "cell_type": "markdown",
   "id": "4178bb92",
   "metadata": {},
   "source": [
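    "Optimizations of Relax/TIR programs can be applied at different granularities:\n",
    "\n",
    "- function level: {class}`tvm.relax.transform.FunctionPass` / {class}`tvm.tir.transform.PrimFuncPass`;\n",
    "- module level: {class}`tvm.transform.ModulePass`.\n",
    "\n",
    "In addition, users can rely on {class}`tvm.transform.Sequential` to apply a sequence of passes to a Relax/TIR program. {class}`~tvm.ir.transform.Sequential` executes its pass objects in order, and the dependencies between passes are resolved by the pass infrastructure in the backend. Users can also specify passes that should not be applied while a sequential pass runs, through the `disabled_pass` argument of {class}`~tvm.ir.transform.PassContext`.\n"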
   ]
  },
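  {
   "cell_type": "markdown",
   "id": "d7b24a6e",
   "metadata": {},
   "source": [
    "A minimal sketch of `Sequential` together with `PassContext`. It assumes the rule current TVM uses to decide whether a pass inside a `Sequential` runs: a pass listed in `disabled_pass` is skipped, a pass listed in `required_pass` always runs, and any other pass runs only if the context's `opt_level` is at least the pass's own `opt_level`. The module `Mod` and the pass `Record` are hypothetical; `Record` only logs the level it ran at, so that skipping is observable.\n",
    "\n",
    "```python\n",
    "import tvm\n",
    "from tvm import relax\n",
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "\n",
    "# Hypothetical module with a repeated expression that CSE can remove\n",
    "@I.ir_module\n",
    "class Mod:\n",
    "    @R.function\n",
    "    def main(x: R.Tensor((2, 2), 'float32')) -> R.Tensor((2, 2), 'float32'):\n",
    "        with R.dataflow():\n",
    "            a = R.add(x, x)\n",
    "            b = R.add(x, x)\n",
    "            gv = R.add(a, b)\n",
    "            R.output(gv)\n",
    "        return gv\n",
    "\n",
    "calls = []\n",
    "\n",
    "# Hypothetical pass with opt_level=4: records the context's opt_level, returns mod unchanged\n",
    "@tvm.transform.module_pass(opt_level=4, name='Record')\n",
    "def record(mod, ctx):\n",
    "    calls.append(ctx.opt_level)\n",
    "    return mod\n",
    "\n",
    "fold = relax.transform.FoldConstant()\n",
    "seq = tvm.transform.Sequential(\n",
    "    [relax.transform.EliminateCommonSubexpr(), fold, record], name='demo_seq'\n",
    ")\n",
    "\n",
    "# opt_level 3 < 4, so Record is skipped; FoldConstant is skipped because it is disabled by name\n",
    "with tvm.transform.PassContext(opt_level=3, disabled_pass=[fold.info.name]):\n",
    "    out = seq(Mod)\n",
    "print(calls)  # expected: []\n",
    "\n",
    "# Listing Record in required_pass makes it run despite its higher opt_level\n",
    "with tvm.transform.PassContext(opt_level=3, required_pass=['Record']):\n",
    "    out = seq(Mod)\n",
    "print(calls)  # expected: [3]\n",
    "out.show()\n",
    "```"
   ]
  },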
  {
   "cell_type": "markdown",
   "id": "7f9d899d",
   "metadata": {},
   "source": [
    "## Module-Level Pass\n",
    "\n",
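    "{class}`tvm.ir.transform.ModulePass` is a pass that works on an {class}`~tvm.ir.module.IRModule`. Users do not need to interact with this class directly. Instead, a module-level pass should be created with {func}`tvm.ir.transform.module_pass`, whose API is flexible enough to create a module pass in different ways: by decorating a class (class mode) or a function (function mode). In addition, all members of a module pass can be accessed from the base class."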
   ]
  },
  {
   "cell_type": "markdown",
   "id": "588249b1",
   "metadata": {},
   "source": [
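    "### Class Mode\n",
    "\n",
    "When `module_pass` decorates a class, every instance of that class is a `ModulePass`. The class must define `transform_module(self, mod, ctx)`, which receives the input `IRModule` and the current `PassContext` and returns the transformed module. The `CustomPipeline` below runs common-subexpression elimination and, when `enable_fold` is set, constant folding."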
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3eb6633c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "from tvm import relax, tir\n",
    "from tvm.script import relax as R\n",
    "\n",
    "@tvm.transform.module_pass(opt_level=2)\n",
    "class CustomPipeline:\n",
    "    def __init__(self, enable_fold):\n",
    "        self.enable_fold = enable_fold\n",
    "        self.cse = relax.transform.EliminateCommonSubexpr()\n",
    "        self.const_fold = relax.transform.FoldConstant()\n",
    "\n",
    "    def transform_module(self, mod, ctx):\n",
    "        mod = self.cse(mod)\n",
    "        if self.enable_fold:\n",
    "            mod = self.const_fold(mod)\n",
    "        return mod\n",
    "\n",
    "# Create an instance of the customized pipeline\n",
    "pipeline = CustomPipeline(enable_fold=False)\n",
    "assert isinstance(pipeline, tvm.transform.ModulePass)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f5f03475",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">main</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
       "        c: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>]\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
       "            conv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>nn<span style=\"color: #A2F; font-weight: bold\">.</span>conv2d(x, weight, strides<span style=\"color: #A2F; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #A2F; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #A2F; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            y: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(c, c)\n",
       "            y_1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>multiply(y, R<span style=\"color: #A2F; font-weight: bold\">.</span>const(<span style=\"color: #008000\">2.0</span>, <span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            y_2: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(conv, y_1)\n",
       "            z: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(y_2, c)\n",
       "            z1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(y_2, c)\n",
       "            z2: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(z, z1)\n",
       "            gv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> z2\n",
       "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">main</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">56</span>, <span style=\"color: #008000\">56</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
       "            conv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>nn<span style=\"color: #A2F; font-weight: bold\">.</span>conv2d(x, weight, strides<span style=\"color: #A2F; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #A2F; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #A2F; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            y: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(conv, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>])\n",
       "            z: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(y, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">1</span>])\n",
       "            z2: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(z, z)\n",
       "            gv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">54</span>, <span style=\"color: #008000\">54</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> z2\n",
       "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "\n",
    "@I.ir_module\n",
    "class Model:\n",
    "    @R.function\n",
    "    def main(\n",
    "        x: R.Tensor((1, 64, 56, 56), dtype=\"float32\"),\n",
    "        weight: R.Tensor((64, 64, 3, 3), dtype=\"float32\"),\n",
    "    ) -> R.Tensor((1, 64, 54, 54), dtype=\"float32\"):\n",
    "        R.func_attr({\"num_input\": 1})\n",
    "        c_data = np.random.rand(1, 64, 54, 54).astype(\"float32\")  # initialized constant data\n",
    "        c = R.const(c_data)\n",
    "        with R.dataflow():\n",
    "            conv = R.nn.conv2d(x, weight)\n",
    "            y = R.add(c, c)\n",
    "            y = R.multiply(y, R.const(2, \"float32\"))\n",
    "            y = R.add(conv, y)\n",
    "            z = R.add(y, c)\n",
    "            z1 = R.add(y, c)\n",
    "            z2 = R.add(z, z1)\n",
    "            gv = z2\n",
    "            R.output(gv)\n",
    "        return gv\n",
    "\n",
    "m = Model\n",
    "m.show()\n",
    "pipeline = CustomPipeline(enable_fold=True)\n",
    "pipeline(m).show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52c80327",
   "metadata": {},
   "source": [
    "### Function Mode"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72732ed9",
   "metadata": {},
   "source": [
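    "The following code creates a module pass by decorating a user-defined transform function. The function receives the input `IRModule` and the current `PassContext` and returns a module; here it returns a new module that contains an extra `abs` function alongside the functions of the input module."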
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69cd6e9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "\n",
    "@tvm.transform.module_pass(opt_level=2)\n",
    "def transform(mod, ctx):\n",
    "    @I.ir_module\n",
    "    class Model:\n",
    "        @R.function\n",
    "        def main(\n",
    "            x: R.Tensor((1, 2), dtype=\"float32\"),\n",
    "        ) -> R.Tensor((1, 2), dtype=\"float32\"):\n",
    "            with R.dataflow():\n",
    "                y = R.abs(x)\n",
    "                gv = y\n",
    "                R.output(gv)\n",
    "            return gv\n",
    "    new_mod = tvm.IRModule()\n",
    "    new_mod['abs'] = Model[\"main\"]\n",
    "    new_mod.update(mod)\n",
    "    return new_mod\n",
    "\n",
    "module_pass = transform\n",
    "assert isinstance(module_pass, tvm.transform.ModulePass)\n",
    "assert module_pass.info.opt_level == 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "419c18a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">abs</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>})\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
       "            y: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>abs(x)\n",
       "            gv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> y\n",
       "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">main</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
       "            y: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(x, x)\n",
       "            gv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> y\n",
       "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "\n",
    "@I.ir_module\n",
    "class Model:\n",
    "    @R.function\n",
    "    def main(\n",
    "        x: R.Tensor((1, 2), dtype=\"float32\"),\n",
    "    ) -> R.Tensor((1, 2), dtype=\"float32\"):\n",
    "        with R.dataflow():\n",
    "            y = R.add(x, x)\n",
    "            gv = y\n",
    "            R.output(gv)\n",
    "        return gv\n",
    "\n",
    "# Given a module `m`, the pass can be invoked as follows:\n",
    "m = Model\n",
    "updated_mod = module_pass(m)\n",
    "# `updated_mod` should now contain the function `abs` in addition to `main`.\n",
    "updated_mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7c53df1",
   "metadata": {},
   "source": [
    "## Relax Function-Level Pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89a10788",
   "metadata": {},
   "source": [
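    "{func}`tvm.relax.transform.function_pass` creates a {class}`~tvm.relax.transform.FunctionPass`, which transforms each Relax function in an `IRModule`. Like `module_pass`, it can decorate either a function or a class."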
   ]
  },
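  {
   "cell_type": "markdown",
   "id": "e8f15c92",
   "metadata": {},
   "source": [
    "A minimal sketch of both decoration styles, mirroring the module-level examples above. It assumes the decorated function takes `(func, mod, ctx)` and returns a `relax.Function`, and that a decorated class defines `transform_function(self, func, mod, ctx)`. `Mod`, `CountFuncs` and `Identity` are hypothetical; both passes return every function unchanged, and `CountFuncs` only records which functions it visited.\n",
    "\n",
    "```python\n",
    "import tvm\n",
    "from tvm import relax\n",
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "\n",
    "@I.ir_module\n",
    "class Mod:\n",
    "    @R.function\n",
    "    def main(x: R.Tensor((1, 2), 'float32')) -> R.Tensor((1, 2), 'float32'):\n",
    "        with R.dataflow():\n",
    "            gv = R.add(x, x)\n",
    "            R.output(gv)\n",
    "        return gv\n",
    "\n",
    "visited = []\n",
    "\n",
    "# Function mode: the pass function is called once per Relax function in the module\n",
    "@relax.transform.function_pass(opt_level=0, name='CountFuncs')\n",
    "def count_funcs(func, mod, ctx):\n",
    "    visited.append(func)\n",
    "    return func  # return the function unchanged\n",
    "\n",
    "# Class mode: transform_function is called once per Relax function in the module\n",
    "@relax.transform.function_pass(opt_level=0)\n",
    "class Identity:\n",
    "    def transform_function(self, func, mod, ctx):\n",
    "        return func\n",
    "\n",
    "assert isinstance(count_funcs, relax.transform.FunctionPass)\n",
    "out = Identity()(count_funcs(Mod))\n",
    "print(len(visited))  # expected: 1 (Mod only contains main)\n",
    "out.show()\n",
    "```"
   ]
  },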
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304b8918",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py313",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
